import pickle
import torch
from typing import Dict, Tuple, List
import numpy as np
class Dataset(object):
    """Knowledge-graph splits loaded from a pickle file, plus evaluation helpers.

    Attributes set by ``__init__``:
        data_ent: dict mapping "train"/"valid"/"test" to the entity triples.
        data_att: dict mapping "train"/"valid"/"test" to the attribute triples.
        lhs_dict, rhs_dict: lookup tables stored in the pickle file.
        entity_num, relation_num: sizes of the entity / relation vocabularies.
        att_num: twice the number of attributes in the pickle file.
        each_att_entity_train, each_att_entity_test: per-attribute entity
            collections, indexed by attribute id in ``att_eval``.
    """

    def __init__(self, path):
        """Load all splits and lookup tables from the pickle file at ``path``.

        The file must contain a 13-tuple in the order unpacked below.
        """
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only load dataset files that come from a trusted source.
        with open(path, "rb") as file:
            (
                ent_train,
                ent_valid,
                ent_test,
                att_train,
                att_valid,
                att_test,
                rhs_dict,
                lhs_dict,
                all_entity,
                all_relation,
                all_att,
                each_att_entity_train,
                each_att_entity_test,
            ) = pickle.load(file)

        self.data_ent = {"train": ent_train, "valid": ent_valid, "test": ent_test}
        self.data_att = {"train": att_train, "valid": att_valid, "test": att_test}

        self.lhs_dict = lhs_dict
        self.rhs_dict = rhs_dict

        self.entity_num = len(all_entity)
        self.relation_num = len(all_relation)
        # Attribute ids span twice the attribute count; att_eval only visits
        # the even ids (presumably the odd ones are paired/inverse ids --
        # TODO confirm against the preprocessing script).
        self.att_num = len(all_att) * 2

        self.each_att_entity_train = each_att_entity_train
        self.each_att_entity_test = each_att_entity_test

    def compute_mae(self, value, predicted_value):
        """Return the element-wise absolute error between the two tensors."""
        return torch.abs(value - predicted_value)

    def compute_mse(self, value, predicted_value):
        """Return the element-wise squared error between the two tensors."""
        return (value - predicted_value) ** 2

    @staticmethod
    def _hits10_and_mrr(scores, masked_scores, answers):
        """Return (hits@10, MRR) for ``answers``, or None if there are none.

        The rank of an answer is 1 plus the number of entries of
        ``masked_scores`` that are strictly greater than the answer's score.
        """
        if answers.numel() == 0:
            return None
        ranks = ((scores[answers].unsqueeze(1) - masked_scores) < 0).sum(dim=-1) + 1
        ranks = ranks.float()
        hits_at_10 = torch.mean((ranks < 10.5).double()).item()
        mrr = torch.mean(torch.reciprocal(ranks)).item()
        return hits_at_10, mrr

    @staticmethod
    def _mean_or_nan(values):
        """Return ``np.mean(values)``, or NaN (without a warning) when empty."""
        return np.mean(values) if values else float("nan")

    def att_eval(self, device, model):
        """Evaluate how well ``model`` ranks the entities that have each attribute.

        For every even attribute id, ``model.score_attribution_exists`` is
        called with that id (as a 0-d tensor on ``device``). Its result is
        indexed by entity id, so it is assumed to be a 1-D tensor of
        per-entity scores -- TODO confirm against the model. Entities listed
        for the attribute in either split are masked out so that each answer
        is ranked against non-answer entities only.

        Attributes whose entity collection is empty for a split contribute
        nothing to that split's averages.

        Returns:
            (hits@10 train, hits@10 test, MRR train, MRR test), each averaged
            over attributes.
        """
        all_att_hits10_train = []
        all_att_hits10_test = []
        all_att_mrr_train = []
        all_att_mrr_test = []

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for i in range(self.att_num // 2):
                att_id = 2 * i
                # Scores of all entities for having this attribute.
                scores = model.score_attribution_exists(torch.tensor(att_id).to(device))

                # An explicit long dtype keeps an empty entity collection
                # usable as an index (the default would be a float tensor).
                train_entity = torch.tensor(
                    list(self.each_att_entity_train[att_id]),
                    dtype=torch.long,
                    device=scores.device,
                )
                test_entity = torch.tensor(
                    list(self.each_att_entity_test[att_id]),
                    dtype=torch.long,
                    device=scores.device,
                )

                # Push every known answer to the bottom of the ranking.
                not_answer_score = scores.clone()
                not_answer_score[train_entity] = -1e6
                not_answer_score[test_entity] = -1e6

                # Rank of the training-set answers.
                train_metrics = self._hits10_and_mrr(scores, not_answer_score, train_entity)
                if train_metrics is not None:
                    all_att_hits10_train.append(train_metrics[0])
                    all_att_mrr_train.append(train_metrics[1])

                # Rank of the test-set answers.
                test_metrics = self._hits10_and_mrr(scores, not_answer_score, test_entity)
                if test_metrics is not None:
                    all_att_hits10_test.append(test_metrics[0])
                    all_att_mrr_test.append(test_metrics[1])

        return (
            self._mean_or_nan(all_att_hits10_train),
            self._mean_or_nan(all_att_hits10_test),
            self._mean_or_nan(all_att_mrr_train),
            self._mean_or_nan(all_att_mrr_test),
        )

    def num_eval(self, device, model, split, n_query, eval_batch=500):
        """Compute MAE, MSE and RMSE of the model's attribute-value predictions.

        Args:
            device: device the examples are moved to before prediction.
            model: object exposing ``predict_for_test(head_entity, attribution)``;
                assumed to return one prediction per row -- TODO confirm.
            split: key into ``self.data_att``.
            n_query: number of randomly chosen examples to evaluate; only
                applied when ``split == "train"``, other splits are
                evaluated in full.
            eval_batch: number of examples per prediction batch.

        Returns:
            Tuple ``(mae, mse, rmse)`` of 0-d tensors.
        """
        if split == "train":
            permutation = torch.randperm(len(self.data_att[split]))[:n_query]
            examples = torch.tensor(self.data_att[split])[permutation].to(device)
        else:
            examples = torch.tensor(self.data_att[split]).to(device)

        total = examples.shape[0]
        outcome_mae = torch.zeros(total)
        outcome_mse = torch.zeros(total)
        with torch.no_grad():
            for begin in range(0, total, eval_batch):
                stop = min(begin + eval_batch, total)
                example_batch = examples[begin:stop, :]
                # Columns are (entity id, attribute id, value).
                head_entity = example_batch[:, 0].int()
                attribution = example_batch[:, 1].int()
                value = example_batch[:, 2]
                predicted_value = model.predict_for_test(head_entity, attribution)

                outcome_mae[begin:stop] = self.compute_mae(value, predicted_value)
                outcome_mse[begin:stop] = self.compute_mse(value, predicted_value)

        mae = torch.mean(outcome_mae)
        mse = torch.mean(outcome_mse)
        rmse = torch.sqrt(mse)
        return (mae, mse, rmse)

class EntityDataset(Dataset):
    """
    Implementation of CQD not using queries, but triples.

    Each item is one entity triple: its first two elements form the query,
    the third is the positive answer, and a fixed number of uniformly
    sampled negative entities is drawn on every access.
    """

    def __init__(self, path, device, mode, negative_sample_size=10):
        """Load the pickle file at ``path`` and select the triples to serve.

        Args:
            path: dataset pickle file, forwarded to ``Dataset.__init__``.
            device: device every returned tensor is moved to.
            mode: only "train" exposes triples; any other mode yields an
                empty dataset (the evaluation helpers remain usable).
            negative_sample_size: number of negatives drawn per item.
        """
        super().__init__(path)
        self.mode = mode
        self.device = device
        self.negative_sample_size = negative_sample_size
        # Always define both attributes so len() never raises AttributeError.
        self.triples = self.data_ent["train"] if self.mode == "train" else []
        self.len = len(self.triples)

    def __len__(self):
        """Return the number of triples served by this dataset."""
        return self.len

    def __getitem__(self, idx):
        """Return ``(positive_sample, negative_sample, query)`` for triple ``idx``.

        ``positive_sample`` has one element, ``negative_sample`` has
        ``negative_sample_size`` elements, ``query`` holds the triple's first
        two elements; all three are tensors on ``self.device``.
        """
        triple = self.triples[idx]
        query = torch.tensor(triple[:2]).to(self.device)
        positive_sample = torch.tensor(triple[2]).unsqueeze(0).to(self.device)

        # Look the answers up once, keyed from the raw triple rather than
        # from the device tensor, so no device-to-host sync is needed.
        # list() also makes set-valued entries usable by np.isin.
        true_answers = np.asarray(list(self.rhs_dict[(int(triple[0]), int(triple[1]))]))

        wanted = self.negative_sample_size
        kept_chunks = []
        kept_count = 0
        # Rejection sampling: oversample, drop known answers, repeat until
        # enough remain. This never terminates if every entity is an answer.
        while kept_count < wanted:
            candidates = np.random.randint(self.entity_num, size=wanted * 2)
            # The candidates may contain duplicates, so uniqueness must not
            # be assumed when testing membership.
            candidates = candidates[np.isin(candidates, true_answers, invert=True)]
            kept_chunks.append(candidates)
            kept_count += candidates.size

        if kept_chunks:
            negatives = np.concatenate(kept_chunks)[:wanted]
        else:
            negatives = np.empty(0, dtype=np.int64)
        negative_sample = torch.tensor(negatives, dtype=torch.long).to(self.device)
        return positive_sample, negative_sample, query

    @staticmethod
    def collate_fn(data):
        """Merge items from ``__getitem__`` into one flat batch dictionary.

        With B items and K negatives each, every value has B * (1 + K)
        elements: the first B rows are the positives, followed by K groups
        of B negatives, with heads and relations tiled to match.
        """
        positive_sample = torch.cat([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        query = torch.stack([_[2] for _ in data], dim=0)

        copies = 1 + negative_sample.shape[-1]
        data_entity = {
            "batch_h": query[:, 0].repeat(copies),
            "batch_r": query[:, 1].repeat(copies),
            # Transpose before flattening so negatives are grouped by
            # negative index, matching the tiling of heads and relations.
            "batch_t": torch.cat((positive_sample, torch.flatten(negative_sample.t()))),
        }
        return data_entity



class AttributeDataset(Dataset):
    """Dataset over attribute triples; each item is read as (entity, attribute, value).

    Dataset for 1ap queries only.
    """

    def __init__(self, path, device, mode):
        """Load the pickle file at ``path`` and select the triples to serve.

        Args:
            path: dataset pickle file, forwarded to ``Dataset.__init__``.
            device: device every returned tensor is moved to.
            mode: only "train" exposes triples; any other mode yields an
                empty dataset (the evaluation helpers remain usable).
        """
        super().__init__(path)
        self.device = device
        self.mode = mode
        # Always define both attributes so len() never raises AttributeError.
        self.triples = self.data_att["train"] if self.mode == "train" else []
        self.len = len(self.triples)

    def __len__(self):
        """Return the number of attribute triples served by this dataset."""
        return self.len

    def __getitem__(self, idx):
        """Return triple ``idx`` as three one-element tensors on ``self.device``.

        The entity and attribute ids become long tensors, the value a float
        tensor.
        """
        query = self.triples[idx]
        entity = torch.LongTensor([query[0]]).to(self.device)
        attribute = torch.LongTensor([query[1]]).to(self.device)
        value = torch.FloatTensor([query[2]]).to(self.device)
        return entity, attribute, value

    @staticmethod
    def collate_fn(data):
        """Concatenate items from ``__getitem__`` into one batch dictionary."""
        batch_e = torch.cat([_[0] for _ in data], dim=0)
        batch_a = torch.cat([_[1] for _ in data], dim=0)
        batch_v = torch.cat([_[2] for _ in data], dim=0)
        data_att = {
            "batch_e": batch_e,
            "batch_a": batch_a,
            "batch_v": batch_v,
        }
        return data_att